from .demo import *

# Guidance header for the planning prompt: asks the model to first produce an
# abstract step sequence and then convert it into waypoints.
guidance_prompt = """
### Guidance: 
    First generate abstract sequence, further convert into waypoints"""

# Action vocabulary the generated waypoints may use. Every motion primitive is
# followed by "(grab)"; presumably this is an optional keep-gripper-closed flag.
# NOTE(review): confirm against whatever parses these actions.
action_prompt = """
### Actions for execution:
    [MOVE_X] target_x (grab)
    [MOVE_Y] target_y (grab)
    [MOVE_Z] target_z (grab)
    [GRIP] 
    [PUSH] target_x target_y target_z (grab)
    [PULL] target_x target_y target_z (grab)"""

############################################################################################################


# Task id -> few-shot demonstration prompt(s) for that task.
# The *_demo_prompt values are not defined in this file; presumably they come from
# `from .demo import *`. get_demonstrations() takes len() of each value and indexes
# it positionally, so each value is expected to be a sequence of demonstration
# strings. NOTE(review): confirm in demo.py.
# The last three ids compose the sub-task keys used by
# task_instructions["metaworld_complex"] (button / drawer / puck / stick).
demos = {
    "door-open-variant-v2": door_open_demo_prompt,
    "drawer-open-variant-v2": drawer_open_demo_prompt,
    "drawer-close-variant-v2": drawer_close_demo_prompt,
    "button-press-variant-v2": button_press_demo_prompt,
    "window-close-variant-v2": window_close_demo_prompt,
    "window-open-variant-v2": window_open_demo_prompt,
    "faucet-open-variant-v2": faucet_open_demo_prompt,
    "push-variant-v2": push_demo_prompt,
    "pick-place-variant-v2": pick_place_demo_prompt,
    "peg-insert-side-variant-v2": peg_insert_side_demo_prompt,
    "button-drawer-puck-stick-variant-v2": button_drawer_puck_stick_demo_prompt,
    "drawer-puck-stick-button-variant-v2": drawer_puck_stick_button_demo_prompt,
    "puck-drawer-button-stick-variant-v2": puck_drawer_button_stick_demo_prompt,
}


def get_demonstrations(task_name, demo_table=None):
    """Render every demonstration for *task_name* as numbered prompt sections.

    Each entry becomes ``"## Demonstrations {k}\\n{demo}\\n\\n"`` with ``k``
    counting from 1; the sections are concatenated in order.

    Args:
        task_name: key into the demonstration table, e.g. "door-open-variant-v2".
        demo_table: optional mapping of task id -> sequence of demonstrations.
            Defaults to the module-level ``demos`` table.

    Returns:
        The concatenated sections, or ``""`` when the task has no demonstrations.

    Raises:
        KeyError: if *task_name* is not in the table.
    """
    table = demos if demo_table is None else demo_table
    # Single lookup + join instead of re-indexing the table and growing a string.
    return "".join(
        f"## Demonstrations {number}\n{demo}\n\n"
        for number, demo in enumerate(table[task_name], start=1)
    )


# Human-readable instruction per task. "metaworld" is keyed by full task id;
# "metaworld_complex" is keyed by the individual sub-task tokens that make up a
# hyphenated multi-skill task id.
task_instructions = {
    "metaworld": {
        "door-open-variant-v2": "Door Open",
        "drawer-open-variant-v2": "Drawer Open",
        "drawer-close-variant-v2": "Drawer Close",
        "window-open-variant-v2": "Window Open",
        "window-close-variant-v2": "Window Close",
        "button-press-variant-v2": "Button Press",
        "peg-insert-side-variant-v2": "Insert Peg to Goal",
        "push-variant-v2": "Push Object to Goal",
        "faucet-open-variant-v2": "Faucet Open",
        "pick-place-variant-v2": "Pick and Place Object to Goal",
    },
    "metaworld_complex": {
        "puck": "Slide Puck to Goal",
        "drawer": "Close Drawer",
        "button": "Press Button",
        "stick": "Insert Peg to Goal",
    },
}


def get_task_instructions(domain_name, task_name):
    """Return the instruction text for *task_name* within *domain_name*.

    * ``"metaworld"``: direct lookup; an unknown task yields ``""``.
    * ``"metaworld_complex"``: *task_name* is split on ``-``, the ``variant`` and
      ``v2`` tokens are dropped, and the remaining sub-task instructions are
      joined with ``" and "``. An unknown sub-task raises ``KeyError``.
    * any other domain: ``""``.
    """
    if domain_name == "metaworld":
        return task_instructions["metaworld"].get(task_name, "")
    if domain_name != "metaworld_complex":
        return ""

    subtask_text = task_instructions["metaworld_complex"]
    ignored = ("variant", "v2")
    pieces = []
    for token in task_name.split("-"):
        if token in ignored:
            continue
        pieces.append(subtask_text[token])
    return " and ".join(pieces)


# Domain description embedded verbatim into the context prompt. It shows the
# model the controller code that turns a target position into a per-step delta,
# and how "speed" is measured from that delta, so speed constraints
# (see domain_info keys) can be reasoned about.
speed_domain = """
    Given current position and target position, action is calculated with: 

        def get_delta_pos(self, curr_pos):
            target_pos = self.get_target_pos(curr_pos, self.target_value)
            target_pos = self.smooth_target_pos(curr_pos, target_pos)

            pid_pos = self._calculate_pid(curr_pos, target_pos)
            delta_pos = move(curr_pos, pid_pos, p=self.p)

            return delta_pos

    where move function is defined as:

        def move(from_xyz, to_xyz, p):
            error = to_xyz - from_xyz
            response = p * error

            if np.any(np.absolute(response) > 1.):
                warnings.warn('Constant(s) may be too high. Environments clip response to [-1, 1]')

            return response
    
            
    and _calculate_pid method is defined as:

        self.previous_error = np.zeros(3)
        self.integral = np.zeros(3)

        self.Kp = 0.65
        self.Ki = 0.01
        self.dt = 0.01

        def _calculate_pid(self, curr_pos, target_pos):
            error = target_pos - curr_pos
            P = self.Kp * error

            self.integral += error * self.dt
            I = self.Ki * self.integral

            self.previous_error = error
            pid_pos = curr_pos + P + I 
            return pid_pos

            
    and smooth_target_pos method is defined as:
        self.n_warmup_steps = warmup_steps
        self.curr_warmup_step = 1

        def smooth_target_pos(self, curr_pos, target_pos):
            if self.curr_warmup_step < self.n_warmup_steps:
                target_pos = curr_pos + (target_pos - curr_pos) * np.square(
                    self.curr_warmup_step / self.n_warmup_steps
                )
                self.curr_warmup_step += 1

            return target_pos

    Lastly speed is calculated as L2 norm of delta position
        speed = np.linalg.norm(delta_pos)

    Speed along an axis is calculated as L2 norm of delta position along that axis
        speed = np.linalg.norm(delta_pos[axis])
"""


# Context-constraint label -> domain description to show the model. Every
# currently supported constraint is speed-related, so all share speed_domain.
domain_info = dict.fromkeys(
    (
        "speed below",
        "speed above",
        "speed above and below",
        "x-axis faster",
        "y-axis faster",
        "faster",
        "slower",
    ),
    speed_domain,
)


def get_context_prompt(context, domain_name, task_name):
    """Build the "### Context Information" section of the planning prompt.

    For ``domain_name == "metaworld"`` *context* is a single record. For any
    other domain it is a sequence of records, paired in order with the
    sub-tasks named in *task_name*. In each record, ``record[1]`` is a
    description string (the phrase "a loss function" is rewritten to
    "waypoints") and ``record[2]`` is a key into ``domain_info``.
    NOTE(review): record layout inferred from indexing only — confirm with caller.

    Args:
        context: one record ("metaworld") or a sequence of records (other domains).
        domain_name: e.g. "metaworld" or "metaworld_complex".
        task_name: hyphen-separated task id, e.g. "button-drawer-puck-stick-variant-v2".

    Returns:
        The context section as a string.

    Raises:
        KeyError: if a record's ``[2]`` entry is not a key of ``domain_info``.
        IndexError: if *context* or a record is too short (e.g. an empty
            sequence for a non-"metaworld" domain).
    """
    import logging

    # Previously two unconditional print() calls; keep the diagnostics but
    # route them through logging so library callers are not spammed.
    logging.getLogger(__name__).debug(
        "context prompt: domain=%s task=%s n_context=%d",
        domain_name,
        task_name,
        len(context),
    )

    if domain_name == "metaworld":
        description = context[1].replace("a loss function", "waypoints")
        domain_text = domain_info[context[2]]
        return f"""
    ### Context Information:
        Other than the task instruction, you are also provided with additional context information.
        You should consider this information while generating the waypoints.

        {description}

        To achieve the task, you should refer to following domain information:

        {domain_text}
    """

    # Drop the id suffix tokens (mirrors get_task_instructions) so an extra
    # context record can never be labelled as task "variant" or "v2".
    subtasks = [t for t in task_name.split("-") if t not in ("variant", "v2")]

    context_prompt = """
    ### Context Information:
        Other than the task instruction, you are also provided with additional context information.
        You should consider this information while generating the waypoints.
    """
    # zip() stops at the shorter of sub-tasks / context records.
    for task, ctx in zip(subtasks, context):
        description = ctx[1].replace("a loss function", "waypoints")
        context_prompt += f"""
            For task {task}, you should consider following:

            {description}
            """

    # Only the first record's domain key is consulted for the shared domain text.
    domain_text = domain_info[context[0][2]]
    context_prompt += f"""
        To achieve the task, you should refer to following domain information:

            {domain_text}
    """

    return context_prompt
